{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
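    "# Knet RNN example\n",
    "Trains a GRU classifier on the IMDB reviews dataset with Knet, timing the training epochs and reporting accuracy on the test set.\n",
    "\n",
    "**TODO**: Port this example to the new Knet RNN interface, and consider adding dropout."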
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Install Knet on first use, then load it\n",
    "using Pkg\n",
    "if !haskey(Pkg.installed(), \"Knet\")\n",
    "    Pkg.add(\"Knet\")\n",
    "end\n",
    "using Knet\n",
    "# The shared hyperparameters live in a Python file; alias `True` so that file can be include()d as Julia\n",
    "True = true\n",
    "include(\"common/params_lstm.py\")\n",
    "gpu()  # last expression: the cell displays whatever gpu() returns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS: Linux\n",
      "Julia: 1.0.0\n",
      "Knet: 1.0.1+\n",
      "GPU: Tesla K80\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Report the environment of this run, one \"label: value\" line per entry.\n",
    "# Each value sits behind a closure so it is only evaluated when its line is printed.\n",
    "for (label, getvalue) in (\n",
    "        (\"OS: \",    () -> Sys.KERNEL),\n",
    "        (\"Julia: \", () -> VERSION),\n",
    "        (\"Knet: \",  () -> Pkg.installed()[\"Knet\"]),\n",
    "        (\"GPU: \",   () -> read(`nvidia-smi --query-gpu=name --format=csv,noheader`, String)),\n",
    "    )\n",
    "    println(label, getvalue())\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model definition.\n",
    "# initmodel() returns (rnnSpec, weights): the spec from rnninit plus the tuple\n",
    "# weights = (rnnWeights, inputMatrix, outputMatrix) that predict() unpacks.\n",
    "function initmodel()\n",
    "    # GRU created by rnninit from EMBEDSIZE and NUMHIDDEN\n",
    "    spec, gruParams = rnninit(EMBEDSIZE, NUMHIDDEN; rnnType = :gru)\n",
    "    # embedding table indexed by token id in predict(): (EMBEDSIZE, MAXFEATURES)\n",
    "    embedding = KnetArray(xavier(Float32, EMBEDSIZE, MAXFEATURES))\n",
    "    # output projection applied to the final RNN output: (2, NUMHIDDEN)\n",
    "    classifier = KnetArray(xavier(Float32, 2, NUMHIDDEN))\n",
    "    return spec, (gruParams, embedding, classifier)\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define loss and its gradient\n",
    "# Forward pass: map a minibatch of token-index sequences to class scores.\n",
    "#   weights : (rnnWeights, inputMatrix, outputMatrix) as built by initmodel()\n",
    "#   inputs  : collection of B equal-length index vectors (length T each)\n",
    "#   rnnSpec : RNN spec returned by rnninit\n",
    "# Returns a (2,B) score matrix computed from the RNN output at the last time step.\n",
    "function predict(weights, inputs, rnnSpec)\n",
    "    rnnWeights, inputMatrix, outputMatrix = weights # (1,1,W), (X,V), (2,H)\n",
    "    # reduce(hcat, ...) builds the same (T,B) matrix as hcat(inputs...) without\n",
    "    # splatting one argument per sequence\n",
    "    indices = permutedims(reduce(hcat, inputs)) # (B,T)\n",
    "    rnnInput = inputMatrix[:,indices] # (X,B,T)\n",
    "    rnnOutput = rnnforw(rnnSpec, rnnWeights, rnnInput)[1] # (H,B,T)\n",
    "    return outputMatrix * rnnOutput[:,:,end] # (2,H) * (H,B) = (2,B)\n",
    "end\n",
    "\n",
    "# Negative log-likelihood (nll) of the labels y under the scores from predict(w,x,r)\n",
    "function loss(w, x, y, r)\n",
    "    scores = predict(w, x, r)\n",
    "    return nll(scores, y)\n",
    "end\n",
    "# Gradient function of loss, called with the same arguments as loss\n",
    "lossgradient = grad(loss);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Loading IMDB...\n",
      "└ @ Main /kuacc/users/dyuret/.julia/dev/Knet/data/imdb.jl:57\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 12.264265 seconds (30.71 M allocations: 1.557 GiB, 7.87% gc time)\n",
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n",
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n"
     ]
    }
   ],
   "source": [
    "# Load IMDB through the data/imdb.jl script located via Knet.dir\n",
    "include(Knet.dir(\"data\", \"imdb.jl\"))\n",
    "@time (xtrn,ytrn,xtst,ytst,imdbdict)=imdb(maxlen=MAXLEN,maxval=MAXFEATURES)\n",
    "# Print a one-line summary of each returned train/test array\n",
    "foreach(d -> println(summary(d)), (xtrn, ytrn, xtst, ytst))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "150-element Array{String,1}:\n",
       " \"have\"     \n",
       " \"had\"      \n",
       " \"it\"       \n",
       " \"was\"      \n",
       " \"an\"       \n",
       " \"excellent\"\n",
       " \"script\"   \n",
       " \"anyway\"   \n",
       " \"and\"      \n",
       " \"an\"       \n",
       " \"excellent\"\n",
       " \"direction\"\n",
       " \"a\"        \n",
       " ⋮          \n",
       " \"for\"      \n",
       " \"his\"      \n",
       " \"next\"     \n",
       " \"movie\"    \n",
       " \"i'm\"      \n",
       " \"hoping\"   \n",
       " \"the\"      \n",
       " \"best\"     \n",
       " \"for\"      \n",
       " \"all\"      \n",
       " \"of\"       \n",
       " \"them\"     "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Invert imdbdict (word => index) into an index => word lookup table.\n",
    "# The table is sized from the largest index in the dictionary rather than a\n",
    "# hard-coded vocabulary size, so every assignment below stays in bounds.\n",
    "imdbarray = Array{String}(undef, maximum(values(imdbdict)))\n",
    "for (k,v) in imdbdict; imdbarray[v]=k; end\n",
    "# Decode the first training review back into words\n",
    "imdbarray[xtrn[1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare for training: drop weights left over from a previous run and reclaim their memory\n",
    "weights = nothing\n",
    "Knet.gc()\n",
    "# Build a fresh model and its Adam optimizer state from the shared hyperparameters\n",
    "rnnSpec, weights = initmodel()\n",
    "optim = optimizers(weights, Adam; lr = LR, beta1 = BETA_1, beta2 = BETA_2, eps = EPS);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 15.862389 seconds (13.36 M allocations: 748.401 MiB, 6.18% gc time)\n"
     ]
    }
   ],
   "source": [
    "# Cold start: one timed pass over the training set, kept separate from the timed epochs\n",
    "@time for (x, y) in minibatch(xtrn, ytrn, BATCHSIZE; shuffle = true)\n",
    "    update!(weights, lossgradient(weights, x, y, rnnSpec), optim)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-initialize after the cold-start pass so the timed training begins from new weights;\n",
    "# the old weights are released first so their memory can be reclaimed\n",
    "weights = nothing\n",
    "Knet.gc()\n",
    "rnnSpec, weights = initmodel()\n",
    "optim = optimizers(weights, Adam; lr = LR, beta1 = BETA_1, beta2 = BETA_2, eps = EPS);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Training...\n",
      "└ @ Main In[12]:2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 10.262852 seconds (388.52 k allocations: 74.902 MiB, 5.52% gc time)\n",
      "  9.253428 seconds (352.37 k allocations: 72.635 MiB, 6.45% gc time)\n",
      "  9.297393 seconds (353.10 k allocations: 72.646 MiB, 6.71% gc time)\n",
      " 28.816111 seconds (1.10 M allocations: 220.229 MiB, 6.20% gc time)\n"
     ]
    }
   ],
   "source": [
    "# ~29s in the recorded run: every epoch is timed, then the whole loop\n",
    "@info \"Training...\"\n",
    "@time for _ in 1:EPOCHS\n",
    "    @time for (x, y) in minibatch(xtrn, ytrn, BATCHSIZE; shuffle = true)\n",
    "        update!(weights, lossgradient(weights, x, y, rnnSpec), optim)\n",
    "    end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Testing...\n",
      "└ @ Main In[13]:1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  3.205257 seconds (1.80 M allocations: 149.280 MiB, 1.25% gc time)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.8431089743589744"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@info \"Testing...\"\n",
    "# Accuracy over the test minibatches; the closure supplies rnnSpec to predict\n",
    "@time accuracy(weights, minibatch(xtst, ytst, BATCHSIZE), (w, x) -> predict(w, x, rnnSpec))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.0.0",
   "language": "julia",
   "name": "julia-1.0"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.0.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
